跳到主要内容

PyTorch 加载数据集

使用 transforms 加载数据集

transforms 是 PyTorch 的 torchvision 库中的一个非常有用的模块,它提供了一系列预处理功能,可以在加载数据时直接应用于数据集。这对于图像数据集特别有用,因为你经常需要进行诸如裁剪、归一化、增强等操作。

以下是一个使用 transforms 加载 CIFAR10 数据集的示例:

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# 定义转换操作
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(), # 将 PIL 图像转换为张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])

# 加载数据集并应用转换
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

在上述代码中,我们首先定义了一个转换操作序列,然后在加载 CIFAR10 数据集时将其应用。这样,每次从数据集中提取图像时,都会自动应用这些转换操作。

使用 transforms 可以大大简化数据预处理和增强的工作流程,并确保在训练和评估时都使用相同的预处理步骤。

自定义加载训练集

在 PyTorch 中,加载自定义训练集通常涉及以下几个步骤:

  1. 定义数据集类:创建一个继承自 torch.utils.data.Dataset 的类,并实现 __len____getitem__ 方法。
  2. 创建数据集实例:使用定义的数据集类创建数据集的实例。
  3. 创建数据加载器:使用 torch.utils.data.DataLoader 创建一个数据加载器,以便在训练和评估时批量加载数据。

以下是一个简单的示例,展示如何加载自定义数据集:

1. 定义数据集类

首先,我们需要定义一个数据集类,该类继承自 torch.utils.data.Dataset

from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.image_files = [f for f in os.listdir(root_dir) if f.endswith('.png')]

def __len__(self):
return len(self.image_files)

def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.image_files[idx])
image = Image.open(img_name)
label = self.image_files[idx].split('_')[1].split('.')[0] # 假设标签是文件名中的某部分
if self.transform:
image = self.transform(image)
return image, label

在上述代码中,我们定义了一个 CustomDataset 类,该类从指定的 root_dir 加载图像,并从文件名中提取标签。

2. 创建数据集实例

使用定义的数据集类创建数据集的实例:

import torchvision.transforms as transforms

transform = transforms.Compose([
transforms.Resize((128, 64)),
transforms.ToTensor(),
])

train_dataset = CustomDataset(root_dir='captcha', transform=transform)

3. 创建数据加载器

使用 torch.utils.data.DataLoader 创建数据加载器:

from torch.utils.data import DataLoader

trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

现在,你可以在训练循环中使用 trainloader 来迭代数据:

for batch_data, batch_labels in trainloader:
# 进行训练或评估
pass

这就是在 PyTorch 中加载自定义数据集的基本方法。你可以根据自己的需求进一步扩展和自定义这个过程。

batch_size 的设置

batch_size 是一个参数,用于指定从数据集中一次提取的样本数量,以进行一次迭代的训练或评估。

在深度学习中,我们通常不会一次处理整个数据集,因为这样可能会导致内存不足或计算效率低下。相反,我们将数据集分成多个小批次(batch),每个批次包含一定数量的样本。这种方法称为小批次梯度下降(Mini-batch Gradient Descent)。

具体来说,batch_size=64 意味着每次从 train_dataset 中提取 64 个样本进行训练。这样,网络的权重会在每个批次后更新,而不是在整个数据集上进行一次完整的前向和反向传播后更新。

以下是 batch_size 的一些关键点:

  1. 计算效率:使用小批次可以利用现代硬件(特别是 GPU)的并行处理能力,从而提高计算效率。
  2. 内存使用:较小的批次可以减少内存使用,使得大型模型和数据集可以在有限的内存中进行训练。
  3. 收敛速度:与整批次梯度下降相比,小批次梯度下降通常可以更快地收敛,但可能会在达到最小值时出现震荡。
  4. 泛化性能:由于每次更新都是基于小批次的数据,这为模型提供了一定的随机性,有助于防止过拟合。

选择合适的 batch_size 是一个实验问题,可能需要根据具体的应用和硬件进行调整。

下面举个具体的例子说明

import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# 1. 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 2. 加载数据集
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)

# 假设我们只使用600张图片
subset_dataset = torch.utils.data.Subset(dataset, indices=range(600))

# 3. 创建数据加载器
batch_size = 10
dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True)

# 4. 定义一个简单的模型和损失函数
model = torch.nn.Linear(3 * 32 * 32, 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 5. 训练模型
num_epochs = 10
for epoch in range(num_epochs):
for i, (inputs, labels) in enumerate(dataloader):
# 将输入展平
inputs = inputs.view(inputs.size(0), -1)

# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)

# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 打印每个批次的损失
if (i+1) % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/60], Loss: {loss.item():.4f}")

在上述代码中:

  • 我们首先加载 CIFAR10 数据集并选择其中的600张图片。
  • 使用 batch_size = 10,所以每个 epoch 有60个批次。
  • 我们训练模型10个 epoch,所以整个数据集会被处理10次。
  • 在每个 epoch 中,我们迭代60个批次,对每个批次的数据进行前向和反向传播,并更新模型的权重。